#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from npu_bridge.npu_init import *
import os
import tensorflow as tf
import argparse
import shutil
# Command-line configuration. Defaults point at the ModelArts container
# mount points for input data and training outputs.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_path', '/home/ma-user/modelarts/inputs/data_url_0/',
                    """Root directory of data""")
flags.DEFINE_string('output_path', '/home/ma-user/modelarts/outputs/train_url_0/',
                    """Directory where to write event logs and checkpoint. """)
flags.DEFINE_integer("epochs", 50, """ epochs of training""")


def train(net, data):
    """Run the 3D-BoNet training loop.

    Args:
        net: BoNet-like object exposing a TF session (``net.sess``), the
            optimizer and loss tensors, graph placeholders (``X_pc``,
            ``Y_bbvert``, ``Y_pmask``, ``Y_psem``, ``lr``, ``is_train``),
            summary writers and a ``saver``.
        data: S3DIS data loader providing ``shuffle_train_files``,
            ``load_train_next_batch``, ``load_test_next_batch_random`` and
            ``total_train_batch_num``.

    Side effects: writes train/test summaries, saves checkpoints under
    ``FLAGS.output_path`` and runs a full evaluation every 5 epochs.
    Relies on module-level ``test_areas``/``train_areas``/``dataset_path``
    being set by the ``__main__`` block.
    """
    # Epoch count comes from the flag (inclusive of the final epoch so the
    # "every 5 epochs" checkpoint/eval fires at the last one). The original
    # hard-coded range(0, 51, 1); FLAGS.epochs + 1 == 51 at the default.
    for ep in range(FLAGS.epochs + 1):
        # Step decay: halve the LR every 20 epochs, floored at 1e-5.
        l_rate = max(0.0005 / (2 ** (ep // 20)), 0.00001)

        data.shuffle_train_files(ep)
        # NOTE(review): this was hard-coded to 30 (debug leftover, with the
        # real expression left in a comment); restored to the loader's count.
        total_train_batch_num = data.total_train_batch_num
        print('total train batch num:', total_train_batch_num)
        for i in range(total_train_batch_num):
            ###### training
            # Only the first 9 point features (xyz + color + normalized xyz,
            # presumably — confirm against the data loader) are fed in.
            bat_pc, _, _, bat_psem_onehot, bat_bbvert, bat_pmask = data.load_train_next_batch()
            _, ls_psemce, ls_bbvert_all, ls_bbvert_l2, ls_bbvert_ce, ls_bbvert_iou, ls_bbscore, ls_pmask = net.sess.run(
                [
                    net.optim, net.psemce_loss, net.bbvert_loss, net.bbvert_loss_l2, net.bbvert_loss_ce,
                    net.bbvert_loss_iou, net.bbscore_loss, net.pmask_loss],
                feed_dict={net.X_pc: bat_pc[:, :, 0:9], net.Y_bbvert: bat_bbvert, net.Y_pmask: bat_pmask,
                           net.Y_psem: bat_psem_onehot, net.lr: l_rate, net.is_train: True})
            if i % 200 == 0:
                # Periodic train summary (eval mode so BN/dropout are frozen).
                sum_train = net.sess.run(net.sum_merged,
                                         feed_dict={net.X_pc: bat_pc[:, :, 0:9], net.Y_bbvert: bat_bbvert,
                                                    net.Y_pmask: bat_pmask, net.Y_psem: bat_psem_onehot, net.lr: l_rate,
                                                    net.is_train: False})
                net.sum_writer_train.add_summary(sum_train, ep * total_train_batch_num + i)
            print('ep', ep, 'i', i, 'psemce', ls_psemce, 'bbvert', ls_bbvert_all, 'l2', ls_bbvert_l2, 'ce',
                  ls_bbvert_ce, 'siou', ls_bbvert_iou, 'bbscore', ls_bbscore, 'pmask', ls_pmask)

            ###### random testing
            if i % 200 == 0:
                bat_pc, _, _, bat_psem_onehot, bat_bbvert, bat_pmask = data.load_test_next_batch_random()
                ls_psemce, ls_bbvert_all, ls_bbvert_l2, ls_bbvert_ce, ls_bbvert_iou, ls_bbscore, ls_pmask, sum_test, pred_bborder = net.sess.run(
                    [
                        net.psemce_loss, net.bbvert_loss, net.bbvert_loss_l2, net.bbvert_loss_ce, net.bbvert_loss_iou,
                        net.bbscore_loss, net.pmask_loss, net.sum_merged, net.pred_bborder],
                    feed_dict={net.X_pc: bat_pc[:, :, 0:9], net.Y_bbvert: bat_bbvert, net.Y_pmask: bat_pmask,
                               net.Y_psem: bat_psem_onehot, net.is_train: False})
                net.sum_write_test.add_summary(sum_test, ep * total_train_batch_num + i)
                print('ep', ep, 'i', i, 'test psem', ls_psemce, 'bbvert', ls_bbvert_all, 'l2', ls_bbvert_l2, 'ce',
                      ls_bbvert_ce, 'siou', ls_bbvert_iou, 'bbscore', ls_bbscore, 'pmask', ls_pmask)
                print('test pred bborder', pred_bborder)

            ###### saving model
            # Rolling checkpoint at the start and end of every epoch, plus a
            # numbered snapshot every 5 epochs, all under the output volume.
            if i == total_train_batch_num - 1 or i == 0:
                net.saver.save(net.sess, save_path=os.path.join(FLAGS.output_path, 'model.cptk'))
                print("ep", ep, " i", i, " model saved!")
            if ep % 5 == 0 and i == total_train_batch_num - 1:
                net.saver.save(net.sess, save_path=os.path.join(FLAGS.output_path,
                                                                    'model' + str(ep).zfill(3) + '.cptk'))

            ###### full eval, if needed
            if ep % 5 == 0 and i == total_train_batch_num - 1:
                # Deferred import: main_eval is only needed for the periodic
                # full evaluation pass.
                from main_eval import Evaluation
                result_path = './log/test_res/' + str(ep).zfill(3) + '_' + test_areas[0] + '/'
                Evaluation.ttest(net, data, result_path, test_batch_size=4)
                # NOTE(review): evaluation is invoked with train_areas while
                # the result dir is named after test_areas[0] — confirm this
                # is intentional.
                Evaluation.evaluation(dataset_path, train_areas, result_path)
                print('full eval finished!')


############
if __name__ == '__main__':
    os.system("cd /home/ma-user/modelarts/outputs/train_url_0/")
    import os
    from main_3D_BoNet import BoNet
    from helper_data_s3dis import Data_Configs as Data_Configs

    configs = Data_Configs()
    net = BoNet(configs=configs)
    net.creat_folders(name='log', re_train=False)
    # net.build_graph()

    ####
    from helper_data_s3dis import Data_S3DIS as Data

    train_areas = ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
    test_areas = ['Area_5']

    dataset_path = FLAGS.data_path
    train_url = FLAGS.output_path
    
    # train_mod_url = os.path.join(FLAGS.output_path, "npu_train_mod_url")
    # train_sum_url = os.path.join(FLAGS.output_path, "npu_train_sum_url")
    # test_sum_url = os.path.join(FLAGS.output_path, "npu_test_sum_url")
    # train_result_url = os.path.join(FLAGS.output_path, "npu_train_result_url")

    # if not os.path.exists(train_mod_url):
    #     os.mkdir(train_mod_url)
    # 
    # if not os.path.exists(train_sum_url):
    #     os.mkdir(train_sum_url)
    # 
    # if not os.path.exists(test_sum_url):
    #     os.mkdir(test_sum_url)
    # 
    # if not os.path.exists(train_result_url):
    #     os.mkdir(train_result_url)

    
    train_batch_size = 4
    data = Data(dataset_path, train_areas, test_areas, train_batch_size)
    bat_pc, _, _, bat_psem_onehot, bat_bbvert, bat_pmask = data.load_train_next_batch()

    net.build_graph(tf.convert_to_tensor(bat_pc[:, :, 0:9]), tf.convert_to_tensor(bat_bbvert),
                    tf.convert_to_tensor(bat_pmask), tf.convert_to_tensor(bat_psem_onehot))

    train(net, data)
